# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
from typing import Optional, Union

import pytest
import tvm
import tvm.script
import tvm.testing
from tvm import IRModule, relax, tir, topi
from tvm.ir import VDevice, DummyGlobalInfo
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T


def _check(
    parsed: Union[relax.Function, IRModule],
    expect: Optional[Union[relax.Function, IRModule]] = None,
):
    test = parsed.script(show_meta=True)
    roundtrip_mod = tvm.script.from_source(test)
    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
    if isinstance(parsed, IRModule) and isinstance(roundtrip_mod, IRModule):
        assert relax.analysis.well_formed(parsed)
        assert relax.analysis.well_formed(roundtrip_mod)
    if expect:
        tvm.ir.assert_structural_equal(parsed, expect)


def test_simple_func():
    @R.function
    def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
        R.func_attr({"Primitive": True})
        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32"))
        gv1 = R.call_dps_packed("extern_dps_func", gv0, R.Tensor((128, 128), dtype="float32"))
        return gv1

    x = relax.Var("x", R.Tensor((128, 128), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,), attrs={"Primitive": True}):
        y = bb.emit(relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")))
        out = bb.emit(
            relax.call_dps_packed("extern_dps_func", y, R.Tensor((128, 128), dtype="float32"))
        )
        bb.emit_func_output(out)

    _check(foo, bb.get()["foo"])


def test_error_report():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
            # error: a = b = c is not allowed.
            gv0 = gv1 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))
            return gv0


def test_mismatch_cast_dims_and_ndim():
    with pytest.raises(Exception):

        @R.function
        def f(
            x: R.Tensor((2, 3), "float32", ndim=3),
        ):  # error: ndim and the shape dims are mismatch
            return x


def test_unexpected_num_kw_args():
    with pytest.raises(Exception):

        @R.function
        def f(x: R.Tensor(dtype="float32", ndim=1, foo=2)):  # error: unexpected kw args foo
            return x


def test_unexpected_ndim():
    with pytest.raises(Exception):

        @R.function
        # error: dim is expected to be non-negative int or -1 for unknown
        def f(x: R.Tensor(dtype="float32", ndim=-2)):
            return x


def test_unexpected_ndim_type():
    with pytest.raises(Exception):

        @R.function
        def f(x: R.Tensor(dtype="float32", ndim="1")):  # error: dim is expected to be int
            return x


def test_unexpected_tir_cast_args():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def f(x: R.Tensor(("m",), "float32")):
            m = T.int64()
            # tir.cast expects 2 arguments, but got 3
            return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32"))


def test_unexpected_tir_args():
    with pytest.raises(tvm.error.DiagnosticError):

        @tvm.script.ir_module
        class TestWellCallTIR:
            @T.prim_func
            def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None:
                T.func_attr(({"global_symbol": "tir_addone"}))
                for i, j in T.grid(16, 16):
                    with T.block("tir_addone"):
                        vi, vj = T.axis.remap("SS", [i, j])
                        B[vi, vj] = A[vi, vj] + T.int32(1)

            @R.function
            def foo(x: R.Tensor(("m", "m"), "float32")):
                m = T.int64()
                # tir.max expects 2 arguments, but got 1
                gv = R.call_tir(tir_addone, (x,), R.Tensor((T.max(16),), dtype="float32"))
                return gv

    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def f(x: R.Tensor(("m", "n"), "float32")):
            m = T.int64()
            # call_tir expected a tir prim_func
            return relax.call_tir("extern_func", (x,), R.Tensor((T.max(m),), dtype="float32"))


def test_func_type_annotation_fail():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def f(x, y):  # error: the parameter type annotation is missing
            z = R.add(x, y)
            y = z
            return y


def test_if_mismatch_var_fail():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")):
            if cond:
                w = R.add(x, x)
                y = R.multiply(w, w)
            else:
                w = R.multiply(x, x)
                z = R.add(w, w)  # error: The binding var is expected to `y`
            return z


def test_unassigned_call_fail():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def f(x: R.Tensor):
            R.add(x, x)
            return x


def test_incorrect_tensor_shape():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def f(x: R.Tensor([16])):
            y: R.Tensor(16) = R.add(x, x)
            return y


def test_simple_module():
    @I.ir_module
    class TestModule:
        @T.prim_func(private=True)
        def tir_func(
            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            for i, j in T.grid(T.int64(128), T.int64(128)):
                with T.block():
                    vi, vj = T.axis.remap("SS", [i, j])
                    y[vi, vj] = x[vi, vj] + 1.0

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
            cls = TestModule
            gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32"))
            return gv0

    x = relax.Var("x", R.Tensor((128, 128), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,), {"global_symbol": "foo"}):
        out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
        bb.emit_func_output(out)

    _check(TestModule, bb.get())


def test_emit_te_primfunc_attrs():
    @I.ir_module
    class TestModule:
        @T.prim_func(private=True)
        def plus_one(
            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        ):
            T.func_attr({"some_attr": "foo", "another_attr": True, "tir.noalias": True})
            for i, j in T.grid(T.int64(128), T.int64(128)):
                with T.block():
                    vi, vj = T.axis.remap("SS", [i, j])
                    y[vi, vj] = x[vi, vj] + 1.0

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
            cls = TestModule
            gv0 = R.call_tir(cls.plus_one, x, R.Tensor((128, 128), dtype="float32"))
            return gv0

    x = relax.Var("x", R.Tensor((128, 128), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,), {"global_symbol": "foo"}):
        out = bb.emit_te(
            lambda x: x + 1,
            x,
            primfunc_name_hint="plus_one",
            primfunc_attrs={"some_attr": "foo", "another_attr": True},
        )
        bb.emit_func_output(out)
    _check(TestModule, bb.get())


def test_emit_te():
    @I.ir_module
    class EmitTE:
        @R.function
        def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20), dtype="float32"):
            lv1 = R.emit_te(topi.add, x, x)
            out = R.emit_te(topi.multiply, lv1, lv1)
            return out

    bb = relax.BlockBuilder()
    x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32"))
    with bb.function("main", [x], {"global_symbol": "main"}):
        lv1 = bb.emit_te(topi.add, x, x)
        out = bb.emit_te(topi.multiply, lv1, lv1)
        bb.emit_func_output(out)

    _check(EmitTE, bb.get())


def test_module_with_attr_and_global_info():
    @I.ir_module
    class TestModule:
        I.module_attrs({"attr": 10})
        I.module_global_infos(
            {
                "dummy": [
                    I.dummy_global_info(),  # dummy[0]
                    I.dummy_global_info(),  # dummy[1]
                ]
            }
        )

        @T.prim_func(private=True)
        def tir_func(
            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            for i, j in T.grid(T.int64(128), T.int64(128)):
                with T.block():
                    vi, vj = T.axis.remap("SS", [i, j])
                    y[vi, vj] = x[vi, vj] + 1.0

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
            cls = TestModule
            gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32"))
            return gv0

    x = relax.Var("x", R.Tensor((128, 128), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,), {"global_symbol": "foo"}):
        out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
        bb.emit_func_output(out)
    mod = bb.get()
    mod.update_global_info("dummy", [DummyGlobalInfo(), DummyGlobalInfo()])
    mod = mod.with_attr("attr", 10)
    _check(TestModule, mod)


def test_global_info_vdevice():
    vdevices = [
        VDevice("llvm"),
        VDevice("cuda", 0),
        VDevice("cuda -arch=sm_80", 0),
        VDevice("metal", 0, "global"),
    ]

    @I.ir_module
    class TestModule:
        I.module_attrs({"attr": 10})
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                    I.vdevice("cuda", 0),
                    I.vdevice("cuda -arch=sm_80", 0),
                    I.vdevice("metal", 0, "global"),
                ]
            }
        )

        @T.prim_func(private=True)
        def tir_func(
            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        ):
            T.func_attr({"tir.noalias": True})
            for i, j in T.grid(T.int64(128), T.int64(128)):
                with T.block():
                    vi, vj = T.axis.remap("SS", [i, j])
                    y[vi, vj] = x[vi, vj] + 1.0

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
            cls = TestModule
            gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32"))
            return gv0

    x = relax.Var("x", R.Tensor((128, 128), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
        bb.emit_func_output(out)
    mod = bb.get()
    mod.update_global_info("vdevice", vdevices)
    mod = mod.with_attr("attr", 10)
    _check(TestModule, mod)


def test_relax_tensor_op():
    @R.function
    def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):
        y = R.add(x, x)
        z = R.multiply(x, y)
        return z

    x = relax.Var("x", R.Tensor((4, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        y = bb.emit(relax.op.add(x, x))
        z = bb.emit(relax.op.multiply(x, y))
        bb.emit_func_output(z)

    _check(foo, bb.get()["foo"])


def test_relax_base_op():
    @R.function
    def foo(x: R.Tensor((4, 4), "float32")):
        alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32")
        shape = R.shape_of(alloc)
        return shape

    x = relax.Var("x", R.Tensor((4, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0))
        shape = bb.emit(relax.op.shape_of(alloc))
        bb.emit_func_output(shape)

    _check(foo, bb.get()["foo"])


def test_relax_shape_to_tensor():
    @R.function
    def foo(x: R.Shape((4, 4))):
        tensor = R.shape_to_tensor(x)
        return tensor

    x = relax.Var("x", R.Shape((4, 4)))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        tensor = bb.emit(relax.op.shape_to_tensor(x))
        bb.emit_func_output(tensor)

    _check(foo, bb.get()["foo"])


def test_symbolic_shape():
    @R.function
    def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
        m = T.int64()
        n = T.int64()
        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32"))
        return gv0

    @R.function
    def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
        m = T.int64()
        n = T.int64()
        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32"))
        return gv0

    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2):
            m = T.int64()
            n = T.int32()  # The shape dtype should be int64
            gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32"))
            return gv0

    def _expected(name: str):
        n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
        x = relax.Var("x", R.Tensor([m, n], "float32"))
        bb = relax.BlockBuilder()
        with bb.function(name, (x,)):
            out = bb.emit(
                relax.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32"))
            )
            bb.emit_func_output(out)
        return bb.get()[name]

    _check(foo, _expected("foo"))
    _check(bar, _expected("bar"))


def test_shadowing():
    @R.function
    def foo(x: R.Tensor((4, 4), "float32")):
        y = R.add(x, x)
        z = R.multiply(x, y)
        y = R.add(x, y)
        y = z
        y = R.multiply(y, x)
        z = y
        return z

    x = relax.Var("x", R.Tensor((4, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        y = bb.emit(relax.op.add(x, x))
        z = bb.emit(relax.op.multiply(x, y))
        y = bb.emit(relax.op.add(x, y))
        y = bb.emit(z)
        y = bb.emit(relax.op.multiply(y, x))
        z = bb.emit(y)
        bb.emit_func_output(z)

    _check(foo, bb.get()["foo"])


def test_match_cast():
    @R.function
    def foo(x: R.Tensor("float32"), y: R.Tensor("float32")):
        m = T.int64()
        n = T.int64()
        x0 = R.match_cast(x, R.Tensor([m], "float32"))
        with R.dataflow():
            y0 = R.match_cast(y, R.Tensor([n], "float32"))
            gv = y0
            R.output(gv)
        return (x0, R.shape([m, n * 2]))

    x = relax.Var("x", R.Tensor("float32"))
    y = relax.Var("y", R.Tensor("float32"))
    m = tir.Var("m", dtype="int64")
    n = tir.Var("n", dtype="int64")
    y2 = relax.Var("y", R.Tensor([n], "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x, y)):
        x0 = bb.match_cast(x, R.Tensor([m], "float32"))
        with bb.dataflow():
            y0 = bb.match_cast(y, R.Tensor([n], "float32"))
            bb.emit_output(y0)
        bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([m, n * 2])]))

    _check(foo, bb.get()["foo"])


def test_tuple_return():
    @R.function
    def foo(x: R.Tensor((4, 4), "float32")):
        gv0 = R.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))
        gv1 = R.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))
        return (gv0, gv1)

    x = relax.Var("x", R.Tensor((4, 4), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        gv0 = bb.emit(relax.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32")))
        gv1 = bb.emit(relax.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), dtype="float32")))
        bb.emit_func_output(relax.Tuple((gv0, gv1)))

    _check(foo, bb.get()["foo"])


def test_tuple_return_2():
    @R.function
    def foo(x: R.Tensor("float32", ndim=2)):
        n, m = T.int64(), T.int64()
        x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
        return (x0, R.shape([n + 1, m, 1]))

    x = relax.Var("x", R.Tensor("float32", ndim=2))
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        x0 = bb.match_cast(x, R.Tensor((n, m), "float32"))
        bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([n + 1, m, 1])]))

    _check(foo, bb.get()["foo"])


def test_tuple_binding():
    @R.function
    def foo(x: R.Tensor("float32", ndim=2)):
        n, m = T.int64(), T.int64()
        x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
        t0 = (x, x0)
        t1 = (x, R.shape([n, m]), t0)
        return t1

    x = relax.Var("x", R.Tensor("float32", ndim=2))
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        x0 = bb.match_cast(x, R.Tensor((n, m), "float32"))
        t0 = bb.emit(relax.Tuple([x, x0]))
        t1 = bb.emit(relax.Tuple([x, relax.ShapeExpr([n, m]), t0]))
        bb.emit_func_output(t1)

    _check(foo, bb.get()["foo"])


def test_tuple_get_item():
    @R.function
    def foo(x: R.Tensor, y: R.Tensor):
        t1 = R.tuple(x, y)
        t2 = (x, y)
        a = t1[0]
        b = R.TupleGetItem(t2, 1)
        c = R.add(a, b)
        return c

    x = relax.Var("x", R.Tensor())
    y = relax.Var("y", R.Tensor())
    bb = relax.BlockBuilder()
    with bb.function("foo", (x, y)):
        t1 = bb.emit(relax.Tuple([x, y]))
        t2 = bb.emit(relax.Tuple([x, y]))
        a = bb.emit(relax.TupleGetItem(t1, 0))
        b = bb.emit(relax.TupleGetItem(t2, 1))
        c = bb.emit(relax.op.add(a, b))
        bb.emit_func_output(c)

    _check(foo, bb.get()["foo"])


def test_dataflow_block():
    @R.function
    def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
        with R.dataflow():
            lv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32"))
            lv1 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))
            gv = lv1
            R.output(gv)
        return gv

    x = relax.Var("x", R.Tensor((128, 128), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        with bb.dataflow():
            lv0 = bb.emit(
                relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32"))
            )
            lv1 = bb.emit(
                relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))
            )
            gv = bb.emit_output(lv1)
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])


def test_dataflow_block_advanced():
    @R.function
    def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32"))
        gv1 = R.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32"))
        with R.dataflow():
            m = T.int64()
            n = T.int64()
            lv0 = R.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32"))
            lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32"))
            gv2 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))
            gv2 = R.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32"))
            gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32"))
            gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32"))
            gv4 = gv3
            gv5 = gv2
            R.output(gv5, gv4)
        gv6 = R.call_dps_packed("extern_func", gv5, R.Tensor((128, 128), dtype="float32"))
        gv7 = R.call_dps_packed("extern_func", gv6, R.Tensor((128, 128), dtype="float32"))
        return gv7

    x = relax.Var("x", R.Tensor((128, 128), "float32"))
    bb = relax.BlockBuilder()
    m = tir.Var("m", dtype="int64")
    n = tir.Var("n", dtype="int64")
    with bb.function("foo", (x,)):
        gv0 = bb.emit(
            relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32"))
        )
        gv1 = bb.emit(
            relax.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32"))
        )
        with bb.dataflow():
            lv0 = bb.emit(
                relax.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32"))
            )
            lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32"))
            gv2 = bb.emit(
                relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))
            )
            gv21 = bb.emit(
                relax.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32"))
            )
            gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32"))
            gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32"))
            gv32 = bb.emit_output(gv31)
            gv22 = bb.emit_output(gv21)
        gv4 = bb.emit(
            relax.call_dps_packed("extern_func", gv22, R.Tensor((128, 128), dtype="float32"))
        )
        gv5 = bb.emit(
            relax.call_dps_packed("extern_func", gv4, R.Tensor((128, 128), dtype="float32"))
        )
        bb.emit_func_output(gv5)

    _check(foo, bb.get()["foo"])


def test_dataflow_binding_after_output():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
            with R.dataflow():
                gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))
                R.output(gv)
                lv = R.call_tir("extern_func", gv, R.Tensor((128, 128), dtype="float32"))
            return gv


def test_dataflow_output_global_var():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
            gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))
            with R.dataflow():
                gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32"))
                R.output(gv0, gv1)
            return gv1


def test_dataflow_multiple_output():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
            with R.dataflow():
                gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))
                R.output(gv)
                R.output(gv)
            return gv


def test_dataflow_output_outside_dataflow_block():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
            gv = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))
            R.output(gv)
            return gv


def test_dataflow_scope_fail():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def f(x: R.Tensor(ndim=2)):
            with R.dataflow():
                y = R.add(x, x)
                z = R.multiply(y, x)
                w = R.add(z, x)
                R.output(y, w)
            t = R.multiply(y, z)  # z is not in the outer scope
            return t


def test_return_without_binding():
    @R.function
    def foo(x: R.Tensor((128, 128), "float32")):
        return x

    x = relax.Var("x", R.Tensor((128, 128), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        bb.emit_func_output(x)

    _check(foo, bb.get()["foo"])


def test_multiple_return():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")):
            return x
            return x


def test_function_without_return():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")):
            gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))


def test_tensor_type_without_args():
    @R.function
    def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        v = R.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32"))
        return v

    x = relax.Var("x", R.Tensor((32, 32), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x)):
        v = bb.emit(relax.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32")))
        bb.emit_func_output(v)

    _check(foo, bb.get()["foo"])


def test_tensor_with_vdevice():
    vdevices = [
        VDevice("llvm"),
        VDevice("cuda", 0),
        VDevice("metal", 0, "global"),
        VDevice("cuda -arch=sm_80", 0),
    ]

    @I.ir_module
    class TestModule:
        I.module_attrs({"attr": 10})
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                    I.vdevice("cuda", 0),
                    I.vdevice("metal", 0, "global"),
                    I.vdevice("cuda -arch=sm_80", 0),
                ]
            }
        )

        @R.function
        def foo(
            a: R.Tensor((128, 128), "float32", "cuda:1"),  # noqa: F722
            b: R.Tensor((128, 128), "float32", "llvm"),
            c: R.Tensor((128, 128), "float32", "vdevice:3"),  # noqa: F722
        ) -> R.Tensor((128, 128), "float32", "cuda:1"):  # noqa: F722
            s = R.add(a, c)
            return s

    a = relax.Var("a", R.Tensor((128, 128), "float32", vdevices[3]))
    b = relax.Var("b", R.Tensor((128, 128), "float32", vdevices[0]))
    c = relax.Var("c", R.Tensor((128, 128), "float32", vdevices[3]))
    bb = relax.BlockBuilder()
    with bb.function("foo", (a, b, c)):
        out = bb.emit(relax.op.add(a, c))
        bb.emit_func_output(out)
    mod = bb.get()
    mod = mod.with_attr("attr", 10)
    mod.update_global_info("vdevice", vdevices)

    _check(TestModule, mod)


def test_direct_return():
    @R.function
    def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
        return x

    x = relax.Var("x", R.Tensor((32, 32), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x)):
        bb.emit_func_output(x)

    _check(foo, bb.get()["foo"])


def test_call_packed():
    @R.function(pure=False)
    def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        z = R.call_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32"))
        return z

    x = relax.Var("x", R.Tensor((32, 32), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x), pure=False):
        z = bb.emit(
            relax.Call(
                relax.ExternFunc("vm.builtin.copy"),
                (x,),
                None,
                sinfo_args=[R.Tensor((32, 32), "float32")],
            )
        )
        bb.emit_func_output(z)

    _check(foo, bb.get()["foo"])


def test_call_packed_without_sinfo_args():
    @R.function(pure=False)
    def foo(x: R.Object) -> R.Object:
        z = R.call_packed("test", x)
        return z

    x = relax.Var("x", R.Object())
    bb = relax.BlockBuilder()
    with bb.function("foo", (x), pure=False):
        z = bb.emit(
            relax.Call(
                relax.ExternFunc("test"),
                (x,),
                None,
                sinfo_args=[],
            )
        )
        bb.emit_func_output(z)

    _check(foo, bb.get()["foo"])


def test_annotation():
    @R.function(pure=False)
    def foo(
        x: R.Tensor((32, "m"), "float32"),
        y: R.Tensor(("m",), "float32"),
        r: R.Tensor(dtype="int64"),
    ) -> R.Object:
        m = T.int64()
        z: R.Tensor((32, m), "float32") = R.multiply(x, y)
        w: R.Tensor(ndim=2) = R.multiply(z, z)
        q: R.Tensor = R.add(w, w)
        t = R.add(w, z)
        sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape)
        lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh)
        o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object)
        return o

    def _check_struct_info(binding, expected_sinfo):
        tvm.ir.assert_structural_equal(binding.var.struct_info, expected_sinfo)
        tvm.ir.assert_structural_equal(binding.value.struct_info, expected_sinfo)

    # Cannot use block builder here because we need to check the annotated type,
    # which may be inconsistent with deduced type.
    assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo)
    m = relax.get_shape_of(foo.params[0])[1]
    bindings = foo.body.blocks[0].bindings
    sh = bindings[4].var

    _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32"))
    _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=2))
    _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=-1))
    _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=2))
    _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1))
    _check_struct_info(bindings[5], relax.TensorStructInfo(sh))
    _check_struct_info(bindings[6], relax.ObjectStructInfo())


def test_annotate_override():
    @R.function
    def foo(x: R.Tensor):
        y = x
        # z will be treated as object type even though it's a tensor
        z: R.Object = R.add(x, y)
        return z

    assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo)
    y_bind, z_bind = foo.body.blocks[0].bindings
    assert isinstance(y_bind.var.struct_info, relax.TensorStructInfo)
    assert isinstance(z_bind.var.struct_info, relax.ObjectStructInfo)

    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def test(x: R.Tensor):
            # Error: x is of Tensor StructInfo, which can not annotate to R.Shape.
            z: R.Shape = x
            return z

    @R.function
    def bar(x: R.Tensor):
        # x is of Tensor StructInfo, the annotation of `z` is ignored.
        z: R.Object = x
        return z

    assert isinstance(bar.ret_struct_info, relax.TensorStructInfo)
    (z_bind,) = bar.body.blocks[0].bindings
    assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo)


def test_call_dps_packed_empty_shape():
    @R.function
    def foo(x: R.Tensor((), "float32")):
        z = R.call_dps_packed("scalar_add", x, R.Tensor((), dtype="float32"))
        return z

    (z_bind,) = foo.body.blocks[0].bindings
    shape_expr = z_bind.value.sinfo_args[0].shape

    assert isinstance(shape_expr, relax.ShapeExpr)
    assert len(shape_expr.values) == 0


def test_call_tir_empty_tuple_arg():
    bb = relax.BlockBuilder()
    dummy_param = relax.Var("dummy_param", R.Tensor(()))
    with bb.function("foo", [dummy_param], {"global_symbol": "foo"}):
        output = bb.emit_te(topi.full, shape=(16, 32), dtype="float32", fill_value=1.0)
        bb.emit_func_output(output)

    _check(bb.get())


def test_call_tir_with_tir_var():
    @I.ir_module
    class Module:
        @R.function
        def main(
            dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2",), "float32")
        ) -> R.Tensor(("n * 2",), "float32"):
            n = T.int64()
            cls = Module
            y = R.call_tir(cls.copy, x, R.Tensor((n * 2,), dtype="float32"), tir_vars=(n,))
            return y

        @T.prim_func
        def copy(var_x: T.handle, var_y: T.handle, n: T.int64):
            X = T.match_buffer(var_x, (n * 2,), dtype="float32")
            Y = T.match_buffer(var_y, (n * 2,), dtype="float32")
            for i in T.grid(n * 2):
                with T.block("block"):
                    vi = T.axis.remap("S", [i])
                    Y[vi] = X[vi]

    _check(Module)


def test_call_tir_with_grad():
    @I.ir_module
    class Module:
        @T.prim_func
        def identity_tir(a: T.handle, b: T.handle) -> None:
            A = T.match_buffer(a, [54, 96])
            B = T.match_buffer(b, [54, 96])

            for i, j in T.grid(54, 96):
                with T.block("compute"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj]

        @R.function
        def main(v0: R.Tensor([54, 96], "float32")):
            cls = Module
            out = R.call_tir_with_grad(
                cls.identity_tir,
                (v0,),
                R.Tensor((54, 96), "float32"),
                te_grad_name="identity_k_grad",
                te_grad_kwargs={"k": 1.0},
            )
            return out

    _check(Module)


def test_call_tir_inplace():
    @tvm.script.ir_module
    class Module:
        @T.prim_func
        def copy(
            A: T.Buffer((2, 3), "int32"),
            B: T.Buffer((2, 3), "int32"),
            out1: T.Buffer((2, 3), "int32"),
        ):
            # copies the contents of B into A and out1
            T.func_attr({"tir.noalias": True})
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("T_zeros"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(B[ax0, ax1])
                    T.writes(A[ax0, ax1], out1[ax0, ax1])
                    A[ax0, ax1] = B[ax0, ax1]
                    out1[ax0, ax1] = B[ax0, ax1]

        @R.function
        def main(
            x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")
        ) -> R.Tuple(
            R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")
        ):
            res = R.call_tir_inplace(
                Module.copy,
                (x, y),
                [0, -1],
                [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")],
            )
            return res

    _check(Module)


def test_call_tir_inplace_with_tuple_var_raises_error():
    with pytest.raises(tvm.error.DiagnosticError):

        @tvm.script.ir_module
        class Module:
            @R.function
            def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")):
                cls = Module
                args = (x, y)
                res = R.call_tir_inplace(
                    cls.copy,
                    # The `args` tuple must be an in-line tuple, not a
                    # reference to a tuple.  This error should be
                    # caught and raised during parsing.
                    args,
                    inplace_indices=[0, -1],
                    out_sinfo=[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")],
                )
                return res

            @T.prim_func
            def copy(
                A: T.Buffer((2, 3), "int32"),
                B: T.Buffer((2, 3), "int32"),
                out1: T.Buffer((2, 3), "int32"),
            ):
                # copies the contents of B into A and out1
                T.func_attr({"tir.noalias": True})
                for iters in T.grid(T.int64(2), T.int64(3)):
                    with T.block("T_zeros"):
                        i, j = T.axis.remap("SS", iters)
                        A[i, j] = B[i, j]
                        out1[i, j] = B[i, j]


def test_local_function():
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        @R.function
        def outer_func(
            c1: R.Tensor((2, 3), "float32")
        ) -> R.Callable((R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2)):
            @R.function
            def inner_func(x1: R.Tensor((2, 3), "float32")):
                s: R.Tensor((2, 3), "float32") = R.add(x1, c1)
                return s

            return inner_func

        in_call = outer_func(x)
        res = in_call(y)
        return res

    main_bindings = main.body.blocks[0].bindings
    assert len(main_bindings) == 3
    outer_func = main_bindings[0].value
    assert isinstance(outer_func, relax.Function)

    outer_func_bindings = outer_func.body.blocks[0].bindings
    assert len(outer_func_bindings) == 1
    inner_func = outer_func_bindings[0].value
    assert isinstance(inner_func, relax.Function)


def test_inline_prim_func():
    with pytest.raises(tvm.error.DiagnosticError):

        @I.ir_module
        class TestModule:
            @R.function
            def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")):
                @T.prim_func
                def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
                    A = T.match_buffer(a, (128, 128))
                    B = T.match_buffer(b, (128, 128))
                    C = T.match_buffer(c, (128, 128))

                    for i, j, k in T.grid(128, 128, 128):
                        with T.block():
                            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                            with T.init():
                                C[vi, vj] = 0.0
                            C[vi, vj] += A[vi, vk] * B[vj, vk]

                z = relax.call_tir(my_matmul, (x, y), R.Tensor((128, 128), dtype="float32"))
                return z


def test_cross_function_call():
    @I.ir_module
    class Mod0:
        @R.function
        def foo(x: R.Tensor((10, 5), "float32")):
            s = R.add(x, x)
            return s

        @R.function
        def main(x: R.Tensor((10, 5), "float32")):
            cls = Mod0
            inner = cls.foo
            gv1 = inner(x)
            gv2 = Mod0.foo(x)
            return (inner, gv1, gv2)

    @I.ir_module
    class Mod1:
        @R.function
        def main(x: R.Tensor((10, 5), "float32")):
            cls = Mod1
            inner = cls.foo
            gv1 = inner(x)
            gv2 = Mod1.foo(x)
            return (inner, gv1, gv2)

        @R.function
        def foo(x: R.Tensor((10, 5), "float32")) -> R.Tensor((10, 5), "float32"):
            s = R.add(x, x)
            return s


def test_if_branch():
    @R.function
    def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor((1,), "float32"):
        if cond:
            w = R.add(x, x)
            y = R.multiply(w, w)
        else:
            w = R.multiply(x, x)
            y = R.add(w, w)
        return y

    cond, x = foo.params
    y_bind = foo.body.blocks[0].bindings[0]
    y, ite = y_bind.var, y_bind.value

    assert isinstance(y, relax.Var)
    assert y.name_hint == "y"

    assert isinstance(ite, relax.If)
    assert isinstance(ite.true_branch, relax.SeqExpr)
    assert isinstance(ite.false_branch, relax.SeqExpr)

    def check_call(call, op, args):
        assert isinstance(call, relax.Call)
        if isinstance(op, str):
            assert call.op.name == op
        else:
            assert call.op == op
        tvm.ir.assert_structural_equal(call.args, args)

    w_bind = ite.true_branch.blocks[0].bindings[0]
    # the seq exprts in the branches are normalized to bind any call
    # in the seq expr "body" to a var
    y_bind = ite.true_branch.blocks[-1].bindings[-1]
    assert w_bind.var.name_hint == "w"
    check_call(w_bind.value, "relax.add", [x, x])
    check_call(y_bind.value, "relax.multiply", [w_bind.var, w_bind.var])

    w_bind = ite.false_branch.blocks[0].bindings[0]
    y_bind = ite.false_branch.blocks[-1].bindings[-1]
    assert w_bind.var.name_hint == "w"
    check_call(w_bind.value, "relax.multiply", [x, x])
    check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var])


def test_if_branch_with_match_cast():
    """The last branch of a relax::If node may be a MatchCast

    This is a regression test.  In previous implementations, using
    R.match_cast as the last binding would cause a segfault while
    parsing.
    """

    @R.function
    def func(A: R.Tensor([16, 16]), is_bfloat16: R.Prim("bool")):
        if is_bfloat16:
            A = R.match_cast(A, R.Tensor([16, 16], "bfloat16"))
            B = A.astype("float16")
        else:
            B = R.match_cast(A, R.Tensor([16, 16], "float16"))
        return B

    A, is_bfloat16 = func.params
    (block,) = func.body.blocks
    (B_binding,) = block.bindings

    B_var = B_binding.var
    assert isinstance(B_var, relax.Var)
    assert B_var.name_hint == "B"

    if_then_else = B_binding.value
    assert isinstance(if_then_else, relax.If)
    assert isinstance(if_then_else.true_branch, relax.SeqExpr)
    assert isinstance(if_then_else.false_branch, relax.SeqExpr)

    else_branch = if_then_else.false_branch
    (else_block,) = else_branch.blocks

    assert isinstance(else_block.bindings[-1], relax.MatchCast)

    # If the `R.match_cast` were removed, the function would infer the
    # return value as `R.Tensor([16,16])`, with an unknown dtype.
    # With the `R.match_cast` retained, the output dtype is known.
    tvm.ir.assert_structural_equal(func.ret_struct_info, R.Tensor([16, 16], "float16"))


def test_if_inside_dataflow():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")):
            with R.dataflow():
                if cond:
                    w = R.add(x, x)
                    y = R.multiply(w, w)
                else:
                    w = R.multiply(x, x)
                    y = R.add(w, w)
                R.output(y)
            return y


def test_var_if_scoping_fail():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def f(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")):
            if cond:
                w = R.add(x, x)
                y = R.multiply(w, w)
            else:
                w = R.multiply(x, x)
                y = R.add(w, w)
            return w  # error: The w is not defined in the outer scope


def test_if_branch_var_scope():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")):
            if cond:
                w = R.add(x, x)
                y = R.multiply(w, w)
            else:
                w = R.multiply(x, x)
                y = R.add(w, w)
            return w


def test_scalar_tensor_as_branch_condition():
    """Branch condition can be 0-d tensor"""

    @R.function
    def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")):
        if cond:
            out = R.add(x, x)
        else:
            out = R.multiply(x, x)
        return out

    if_else = func.body.blocks[0].bindings[0].value
    assert isinstance(if_else.cond, relax.Var)
    tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Tensor([], "bool"))


def test_prim_value_as_branch_condition():
    """In addition to scalar tensor, can use R.Prim condition"""

    @R.function
    def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")):
        if cond:
            out = R.add(x, x)
        else:
            out = R.multiply(x, x)
        return out

    if_else = func.body.blocks[0].bindings[0].value
    assert isinstance(if_else.cond, relax.Var)
    tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim("bool"))


def test_computed_prim_value_as_branch_condition():
    """The R.Prim condition may be computed within the function"""

    @R.function
    def func(x: R.Tensor(["N"], "float32")):
        N = T.int64()
        if R.prim_value(N % 16 == 0):
            out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info])
        else:
            out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info])
        return out

    N = func.params[0].struct_info.shape[0]
    if_else = func.body.blocks[0].bindings[0].value
    assert isinstance(if_else.cond, relax.PrimValue)
    tvm.ir.assert_structural_equal(N % 16 == 0, if_else.cond.value)
    tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim(value=N % 16 == 0))


def test_tir_expr_as_branch_condition():
    """Syntactic sugar, wrap PrimExpr as PrimValue"""

    @R.function(private=True)
    def sugared(x: R.Tensor(["N"], "float32")):
        N = T.int64()
        if N % 16 == 0:
            out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info])
        else:
            out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info])
        return out

    @R.function(private=True)
    def unsugared(x: R.Tensor(["N"], "float32")):
        N = T.int64()
        if R.prim_value(N % 16 == 0):
            out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info])
        else:
            out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info])
        return out

    tvm.ir.assert_structural_equal(unsugared, sugared)


def test_scalar_tensor_as_assert_condition():
    """Branch condition can be 0-d tensor"""

    @R.function(pure=False)
    def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")):
        _ = R.assert_op(cond)
        out = R.add(x, x)
        return out

    assert_op = func.body.blocks[0].bindings[0].value
    condition = assert_op.args[0]
    assert isinstance(condition, relax.Var)
    tvm.ir.assert_structural_equal(condition.struct_info, R.Tensor([], "bool"))


def test_prim_value_as_assert_condition():
    """In addition to scalar tensor, can use R.Prim condition"""

    @R.function(pure=False)
    def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")):
        _ = R.assert_op(cond)
        out = R.add(x, x)
        return out

    assert_op = func.body.blocks[0].bindings[0].value
    condition = assert_op.args[0]
    assert isinstance(condition, relax.Var)
    tvm.ir.assert_structural_equal(condition.struct_info, R.Prim("bool"))


def test_computed_prim_value_as_assert_condition():
    """The R.Prim condition may be computed within the function"""

    @R.function(pure=False)
    def func(x: R.Tensor(["N"], "float32")):
        N = T.int64()
        _ = R.assert_op(R.prim_value(N % 16 == 0))
        out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info])
        return out

    N = func.params[0].struct_info.shape[0]
    assert_op = func.body.blocks[0].bindings[0].value
    condition = assert_op.args[0]
    assert isinstance(condition, relax.PrimValue)
    tvm.ir.assert_structural_equal(N % 16 == 0, condition.value)
    tvm.ir.assert_structural_equal(condition.struct_info, R.Prim(value=N % 16 == 0))


def test_tir_expr_as_assert_condition():
    """Syntactic sugar, wrap PrimExpr as PrimValue"""

    @R.function(pure=False, private=True)
    def sugared(x: R.Tensor(["N"], "float32")):
        N = T.int64()
        _ = R.assert_op(N % 16 == 0)
        out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info])
        return out

    @R.function(pure=False, private=True)
    def unsugared(x: R.Tensor(["N"], "float32")):
        N = T.int64()
        _ = R.assert_op(R.prim_value(N % 16 == 0))
        out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info])
        return out

    tvm.ir.assert_structural_equal(unsugared, sugared)


def test_erase_to_well_defined_removes_internal_vars():
    @R.function
    def foo(x: R.Tensor):
        q = x
        m, n = T.int64(), T.int64()
        z = R.match_cast(q, R.Tensor((m, n)))
        w = z
        return w

    tvm.ir.assert_structural_equal(foo.ret_struct_info, R.Tensor(ndim=2))
    assert foo.ret_struct_info.shape is None
    _check(foo)


def test_erase_to_well_defined_keeps_variables_exposed_by_tensor_shape():
    @R.function
    def foo(x: R.Tensor(["m", "n"])):
        q = x
        m, n = T.int64(), T.int64()
        z = R.match_cast(q, R.Tensor((m, n)))
        w = z
        return w

    assert foo.ret_struct_info.shape is not None
    _check(foo)


def test_erase_to_well_defined_keeps_variants_exposed_by_shape_expr():
    @R.function
    def foo(x: R.Tensor, _: R.Shape(["m", "n"])):
        q = x
        m, n = T.int64(), T.int64()
        z = R.match_cast(q, R.Tensor((m, n)))
        w = z
        return w

    assert foo.ret_struct_info.shape is not None
    _check(foo)


def test_erase_to_well_defined_keeps_variants_exposed_by_prim_value():
    @R.function
    def foo(x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")):
        q = x
        m, n = T.int64(), T.int64()
        z = R.match_cast(q, R.Tensor((m, n)))
        w = z
        return w

    assert foo.ret_struct_info.shape is not None
    _check(foo)


def test_erase_to_well_defined_infers_from_shape_expr():
    @I.ir_module
    class Module:
        # The subroutine's symbolic variables are only in-scope for the subroutine.
        @R.function
        def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]):
            q = x
            m, n = T.int64(), T.int64()
            z = R.match_cast(q, R.Tensor((m, n)))
            w = z
            return w

        # However, struct inference can make the symbolic variables in
        # the main function to the symbolic variables in the
        # subroutine.  Therefore, the shape of the tensor returned
        # from main can have a well-defined shape.
        @R.function
        def main(x: R.Tensor, shape: R.Shape(["m", "n"])):
            output = Module.subroutine(x, shape)
            return output

    assert Module["main"].ret_struct_info.shape is not None
    _check(Module)


def test_erase_to_well_defined_infers_from_prim_value():
    @I.ir_module
    class Module:
        # The subroutine's symbolic variables are only in-scope for the subroutine.
        @R.function
        def subroutine(
            x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")
        ) -> R.Tensor(["m", "n"]):
            q = x
            m, n = T.int64(), T.int64()
            z = R.match_cast(q, R.Tensor((m, n)))
            w = z
            return w

        # However, struct inference can make the symbolic variables in
        # the main function to the symbolic variables in the
        # subroutine.  Therefore, the shape of the tensor returned
        # from main can have a well-defined shape.
        @R.function
        def main(x: R.Tensor, relax_m: R.Prim(value="m"), relax_n: R.Prim(value="n")):
            output = Module.subroutine(x, relax_m, relax_n)
            return output

    assert Module["main"].ret_struct_info.shape is not None
    _check(Module)


def test_empty_tuple():
    @R.function
    def foo(x: R.Tuple()):
        y: R.Tuple() = R.tuple()
        return y

    x = relax.Var("x", relax.TupleStructInfo([]))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x,)):
        y = bb.emit(relax.Tuple([]))
        bb.emit_func_output(y)

    _check(foo, bb.get()["foo"])


def test_symbolic_vars_in_tensor_shape_with_usage_first():
    """First param may use symbolic variable defined in second param"""

    @R.function
    def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")):
        z = R.add(x, y)
        return z

    m = tir.Var("m", "int64")
    x = relax.Var("x", relax.TensorStructInfo([m + 1], "float32"))
    y = relax.Var("y", relax.TensorStructInfo([m, 1], "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x, y)):
        z = bb.emit(relax.op.add(x, y))
        bb.emit_func_output(z)

    _check(foo, bb.get()["foo"])


def test_symbolic_vars_in_tensor_shape_with_definition_first():
    """Second param may use symbolic variable defined in first param"""

    @R.function
    def bar(
        x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32")
    ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"):
        m = T.int64()
        z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32"))
        return z

    m = tir.Var("m", "int64")
    x = relax.Var("x", relax.TensorStructInfo([m], "float32"))
    y = relax.Var("y", relax.TensorStructInfo([tir.max(m, 20)], "float32"))
    bb = relax.BlockBuilder()
    with bb.function("bar", (x, y)):
        z = bb.emit(
            relax.call_dps_packed(
                "test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32")
            )
        )
        bb.emit_func_output(z)

    _check(bar, bb.get()["bar"])


def test_symbolic_vars_in_shape():
    """Symbolic variable may be defined in R.Shape"""

    @R.function
    def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")):
        m = T.int64()
        z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32"))
        return z

    m = tir.Var("m", "int64")
    x = relax.Var("x", relax.ShapeStructInfo([m]))
    y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32"))
    bb = relax.BlockBuilder()
    with bb.function("baz", (x, y)):
        z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32")))
        bb.emit_func_output(z)

    _check(baz, bb.get()["baz"])


def test_symbolic_vars_in_prim_value():
    """Symbolic variable may be defined in R.Prim"""

    @R.function
    def baz(x: R.Prim(value="m"), y: R.Tensor(("m * 2",), "float32")):
        m = T.int64()
        z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32"))
        return z

    m = tir.Var("m", "int64")
    x = relax.Var("x", relax.PrimStructInfo(value=m))
    y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32"))
    bb = relax.BlockBuilder()
    with bb.function("baz", (x, y)):
        z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32")))
        bb.emit_func_output(z)

    _check(baz, bb.get()["baz"])


def test_undefined_symbolic_var_raises_error():
    """An undefined symbolic variable in an error

    A symbolic variables is defined at the first site where it appears
    as a shape parameter without any modification.  TVMScript does not
    support solving for a symbolic variable in terms of the argument
    shape.  That is, this test case raises an error, and will not
    attempt to define `m` as either `x.shape[0]-1` or `x.shape[1]//2`.
    """
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function
        def foo(x: R.Tensor(("m + 1", "m * 2"), "float32")):  # name 'm' is not defined
            z = R.add(x, x)
            return z


def test_arith_operators():
    @R.function
    def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")):
        a0 = -x
        a1 = x + y
        a2 = x - y
        a3 = x * y
        a4 = x / y
        a5 = x // y
        a6 = x**y

        c0 = x > y
        c1 = x < y
        c2 = x >= y
        c3 = x <= y

        tuple_expr = ((x, x), y)
        t0 = tuple_expr[0]
        t1 = tuple_expr[1]
        t2 = tuple_expr[0][0]  # <= Will normalize to two bindings
        return (a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2)

    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")
    x = relax.Var("x", relax.TensorStructInfo([m, n], "float32"))
    y = relax.Var("y", relax.TensorStructInfo([m, n], "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x, y)):
        a0 = bb.emit(relax.op.negative(x))
        a1 = bb.emit(relax.op.add(x, y))
        a2 = bb.emit(relax.op.subtract(x, y))
        a3 = bb.emit(relax.op.multiply(x, y))
        a4 = bb.emit(relax.op.divide(x, y))
        a5 = bb.emit(relax.op.floor_divide(x, y))
        a6 = bb.emit(relax.op.power(x, y))

        c0 = bb.emit(relax.op.greater(x, y))
        c1 = bb.emit(relax.op.less(x, y))
        c2 = bb.emit(relax.op.greater_equal(x, y))
        c3 = bb.emit(relax.op.less_equal(x, y))

        tuple_expr = bb.emit(relax.Tuple((relax.Tuple((x, x)), y)))
        t0 = bb.emit(relax.TupleGetItem(tuple_expr, 0))
        t1 = bb.emit(relax.TupleGetItem(tuple_expr, 1))
        tmp = bb.emit(relax.TupleGetItem(tuple_expr, 0))
        t2 = bb.emit(relax.TupleGetItem(tmp, 0))
        bb.emit_func_output(relax.Tuple((a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2)))

    _check(foo, bb.get()["foo"])


def test_memory_ops():
    @R.function
    def foo(x: R.Tensor(("m", "n"), dtype="float32")):
        m = T.int64()
        n = T.int64()
        storage = R.memory.alloc_storage(
            R.shape([4 * m * n]), virtual_device_index=0, storage_scope="global", dtype="float32"
        )
        alloc = R.memory.alloc_tensor(storage, offset=0, shape=R.shape([m, n]), dtype="float32")
        tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0)
        gv = tensor
        return alloc, gv

    _check(foo)


def test_vm_ops():
    @R.function(pure=False)
    def foo(x: R.Tensor(("m", "n"), dtype="float32")):
        m = T.int64()
        n = T.int64()
        storage = R.vm.alloc_storage(R.shape([4 * m * n]), runtime_device_index=0, dtype="uint8")
        alloc = R.vm.alloc_tensor(storage, offset=0, shape=R.shape([m, n]), dtype="float32")
        tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0)
        tir_dym = R.vm.call_tir_dyn("te_func", (x, tensor, R.ShapeExpr((m, n))))
        return alloc, tir_dym

    _check(foo)


def test_builtin_ops():
    @R.function
    def foo(x: R.Tensor(("m", "n"), dtype="float32")):
        tensor = R.builtin.stop_lift_params(x)
        gv = tensor
        return gv

    _check(foo)


def test_prim_value():
    @R.function(pure=False)
    def foo():
        gv = R.call_packed("test", 1, sinfo_args=R.Tensor((32, 32), "float32"))
        return gv

    _check(foo)


def test_string_imm():
    @R.function(pure=False)
    def foo():
        gv = R.call_packed("test", "hello", sinfo_args=R.Tensor((32, 32), "float32"))
        return gv

    _check(foo)


def test_datatype_imm():
    @R.function(pure=False)
    def foo():
        gv = R.call_packed("test", R.dtype("float32"), sinfo_args=R.Tensor((32, 32), "float32"))
        return gv

    _check(foo)


def test_function_void_return_type():
    @tvm.script.ir_module
    class Foo:
        @R.function
        def main(x: R.Tensor((3, 3), dtype="float32")):
            res = Foo.mul(x)
            return res

        @R.function
        def mul(x: R.Tensor((3, 3), dtype="float32")):
            res = R.multiply(x, x)
            return res

    _check(Foo)
    # Since the return type of function `mul` is not annotated,
    # the function `main` regards it as a generic return type.
    assert isinstance(Foo["main"].ret_struct_info, relax.ObjectStructInfo)
    assert isinstance(Foo["mul"].ret_struct_info, relax.TensorStructInfo)

    @tvm.script.ir_module
    class Bar:
        @R.function
        def main(x1: R.Tensor((3, 3), dtype="float32")):
            res1 = Bar.mul(x1)
            return res1

        @R.function
        def mul(x: R.Tensor((3, 3), dtype="float32")) -> None:
            res = R.multiply(x, x)
            return res

    # Since the return type of function `mul` is not annotated,
    # the function `main` regards it as a generic return type.
    _check(Bar)
    tvm.ir.assert_structural_equal(Bar["main"].ret_struct_info, relax.TupleStructInfo([]))
    tvm.ir.assert_structural_equal(Bar["mul"].ret_struct_info, relax.TupleStructInfo([]))


def test_class_normalize():
    @tvm.script.ir_module
    class InputModule:
        @R.function
        def mul_add(x: R.Tensor) -> R.Tensor:
            return R.multiply(R.add(x, x), R.add(x, x))

    # The parser automatically normalizes the input AST to the following ANF form
    @tvm.script.ir_module
    class OutputModule:
        @R.function
        def mul_add(x: R.Tensor) -> R.Tensor:
            gv = R.add(x, x)
            gv1 = R.add(x, x)
            return R.multiply(gv, gv1)

    _check(InputModule, OutputModule)


def test_context_aware_parsing(monkeypatch):
    @tvm.script.ir_module
    class Module:
        @T.prim_func
        def add(
            X: T.Buffer([T.int64(2), T.int64(4)], "float32"),
            Y: T.Buffer((), "float32"),
            Z: T.Buffer([T.int64(2), T.int64(4)], "float32"),
        ):
            T.evaluate(0)

        @R.function
        def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
            R.func_attr({"relax.force_pure": True})
            cls = Module
            alloc = R.builtin.alloc_tensor(R.shape([2, 4]), dtype="float32", runtime_device_index=0)
            _: R.Tuple() = cls.add(x, R.const(1, "float32"), alloc)
            return alloc

    _check(Module)

    # Break the env settings, but context-aware parsing can still handle it
    def _break_env(self, *args):
        raise RuntimeError("Fail to pass context-aware parsing")

    monkeypatch.setattr(tvm.ir.GlobalVar, "__call__", _break_env)

    _check(Module)


def test_unit_tuple_on_rhs_of_assign():
    @I.ir_module
    class Module:
        @R.function
        def main(input: R.Tensor((5, 5))) -> R.Tuple(R.Tensor((5, 5))):
            gv = (input,)
            return gv

    _check(Module)


def test_empty_tuple_on_rhs_of_assign():
    @I.ir_module
    class Module:
        @R.function
        def main(input: R.Tensor((5, 5))) -> R.Tuple():
            gv = ()
            return gv

    _check(Module)


def test_global_var_sinfo():
    @I.ir_module
    class Module:
        @R.function
        def foo(x: R.Tensor((128, 128), "float32")):
            gv0 = R.emit_te(topi.add, x, x)
            return gv0

    target_sinfo = R.Callable(
        (R.Tensor((128, 128), dtype="float32"),), R.Tensor((128, 128), dtype="float32")
    )
    gv = Module.get_global_var("foo")
    tvm.ir.assert_structural_equal(gv.struct_info, target_sinfo)
    tvm.ir.assert_structural_equal(Module["foo"].struct_info, target_sinfo)
    _check(Module)


def test_assert_op():
    @I.ir_module
    class AssertOp:
        @R.function(pure=False)
        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
            return x

    _check(AssertOp)


def test_assert_outside_of_class():
    @R.function(pure=False)
    def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
        return x

    # this just makes sure that the machinery regarding the pure attribute parses
    # in the case where the function is outside of a class too
    _check(func)


def test_impure_inner_function():
    @R.function
    def f(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        # we will not actually call it
        @R.function(pure=False)
        def g(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            z = R.assert_op(R.const(False, dtype="bool"), y, format="y: {}")
            return y

        return x

    assert f.is_pure
    # definition of g
    assert not f.body.blocks[0].bindings[0].value.is_pure

    # make sure we are not incorrectly passing state for inner functions
    _check(f)


def test_impure_inner_function_in_class():
    @I.ir_module
    class ImpureInner:
        @R.function
        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            # we will not actually call it
            @R.function(pure=False)
            def g(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
                z = R.assert_op(R.const(False, dtype="bool"), y, format="y: {}")
                return y

            return x

    assert ImpureInner["main"].is_pure
    # definition of g
    assert not ImpureInner["main"].body.blocks[0].bindings[0].value.is_pure

    # make sure we are not incorrectly passing state for inner functions
    _check(ImpureInner)


def test_print():
    @I.ir_module
    class Print:
        @R.function(pure=False)
        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.print(x, format="x: {}")
            return x

    _check(Print)


def test_parse_multiple_pure_and_impure_funcs():
    @I.ir_module
    class Mixture:
        @R.function(pure=False)
        def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.print(x, format="x: {}")
            return x

        @R.function(pure=False)
        def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
            return x

        @R.function
        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            return x

    assert not Mixture["print"].is_pure
    assert not Mixture["assert_func"].is_pure
    assert Mixture["main"].is_pure
    _check(Mixture)


def test_function_with_void_return_type_may_be_used_as_statements():
    """Void return of calls do not need to be assigned"""

    @I.ir_module
    class Unsugared:
        @R.function(pure=False)
        def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.print(x, format="x: {}")
            return x

        @R.function(pure=False)
        def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
            return x

    @I.ir_module
    class Sugared:
        @R.function(pure=False)
        def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            R.print(x, format="x: {}")
            return x

        @R.function(pure=False)
        def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
            return x

    tvm.ir.assert_structural_equal(Unsugared, Sugared)


def test_function_with_non_void_return_type_must_be_assigned():
    """Non-void results must be assigned to a variable"""

    with pytest.raises(tvm.error.DiagnosticError):

        @R.function(pure=False)
        def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            R.add(x, x)
            return x


def test_function_with_void_return_type_in_if_else():
    """Last statement in if/else may be a void return"""

    @I.ir_module
    class Unsugared:
        @R.function(pure=False)
        def conditional(
            x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")
        ) -> R.Tensor((), "int32"):
            if condition:
                y = R.print(x, format="True condition: {}")
            else:
                y = R.print(x, format="False condition: {}")
            return x

    @I.ir_module
    class Sugared:
        @R.function(pure=False)
        def conditional(
            x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")
        ) -> R.Tensor((), "int32"):
            if condition:
                R.print(x, format="True condition: {}")
            else:
                R.print(x, format="False condition: {}")
            return x

    _check(Sugared, Unsugared)


def test_call_pure_packed():
    @R.function
    def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        z = R.call_pure_packed("vm.builtin.copy", x, sinfo_args=R.Tensor((32, 32), "float32"))
        return z

    x = relax.Var("x", R.Tensor((32, 32), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x)):
        z = bb.emit(
            R.call_pure_packed("vm.builtin.copy", x, sinfo_args=[R.Tensor((32, 32), "float32")])
        )
        bb.emit_func_output(z)

    _check(foo, bb.get()["foo"])


def test_call_pure_packed_returning_object():
    @R.function
    def foo() -> R.Object:
        z = R.call_pure_packed("dummy_func", sinfo_args=R.Object)
        return z

    bb = relax.BlockBuilder()
    with bb.function("foo", params=[]):
        z = bb.emit(R.call_pure_packed("dummy_func", sinfo_args=[relax.ObjectStructInfo()]))
        bb.emit_func_output(z)

    _check(foo, bb.get()["foo"])


def test_private_function():
    @I.ir_module
    class Addition:
        @R.function(private=True)
        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.add(x, x)
            return y

    x = relax.Var("x", R.Tensor((), "int32"))
    bb = relax.BlockBuilder()
    with bb.function("main", (x), private=True):
        y = bb.emit(R.add(x, x))
        bb.emit_func_output(y)

    _check(Addition, bb.get())


def test_private_function_with_global_symbol_fail():
    with pytest.raises(tvm.error.DiagnosticError):

        @I.ir_module
        class Addition:
            @R.function(private=True)
            def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
                # it is an error to simultaneously mark a function private
                # and give it a global symbol manually
                R.func_attr({"global_symbol": "main"})
                y = R.add(x, x)
                return y

        # should not execute
        _check(Addition)


def test_private_function_with_global_symbol_no_module_fail():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.function(private=True)
        def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            R.func_attr({"global_symbol": "main"})
            y = R.add(x, x)
            return y

        # should not execute
        _check(func)


def test_macro_hygienic():
    x = R.prim_value(2)

    @R.macro(hygienic=True)
    def alloc_and_shape(dtype: str):
        alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=x, dtype=dtype)
        shape = R.shape_of(alloc)
        return shape

    x = R.prim_value(1)

    @R.function(private=True)
    def func(z: R.Tensor((4, 4), "float32")):
        shape = alloc_and_shape(dtype="float32")
        return shape

    @R.function(private=True)
    def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]):
        alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(
            R.shape([4, 4]),
            R.dtype("float32"),
            R.prim_value(2),  # Make sure prim_value is 2
        )
        shape: R.Shape([4, 4]) = R.shape_of(alloc)
        shape_1: R.Shape([4, 4]) = shape
        return shape_1

    _check(func, expect)


def test_macro_non_hygienic():
    global global_x_var  # Lookup doesn't find this variable if it's not global

    global_x_var = R.prim_value(2)

    @R.macro(hygienic=False)
    def alloc_and_shape(dtype: str):
        alloc = R.builtin.alloc_tensor(
            R.shape([4, 4]), runtime_device_index=global_x_var, dtype=dtype
        )
        shape = R.shape_of(alloc)
        return shape

    global_x_var = R.prim_value(1)

    @R.function(private=True)
    def func(z: R.Tensor((4, 4), "float32")):
        shape = alloc_and_shape(dtype="float32")
        return shape

    @R.function(private=True)
    def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]):
        alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(
            R.shape([4, 4]),
            R.dtype("float32"),
            R.prim_value(1),  # Make sure prim_value is 1
        )
        shape: R.Shape([4, 4]) = R.shape_of(alloc)
        shape_1: R.Shape([4, 4]) = shape
        return shape_1

    _check(func, expect)


def test_macro_no_variable_leak():
    with pytest.raises(tvm.error.DiagnosticError):

        @R.macro(hygienic=True)
        def add_two(value):
            x = value + R.const(1)  # `x` defined in macro
            y = x + R.const(1)
            return y

        @R.function(private=True)
        def func(t: R.Tensor((), "int32")):
            u = add_two(t)
            return x  # Should be undefined here


def test_reused_extern_func():
    """ExternFunc lookups can become bindings in EliminateCommonSubexpr"""

    @R.function(private=True)
    def parsed(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
        func = R.ExternFunc("extern_func")
        gv0 = R.call_dps_packed(func, x, R.Tensor((128, 128), dtype="float32"))
        gv1 = R.call_dps_packed(func, gv0, R.Tensor((128, 128), dtype="float32"))
        return gv1

    x = relax.Var("x", R.Tensor((128, 128), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("main", [x], private=True):
        func = bb.emit(relax.ExternFunc("extern_func"))
        y = bb.emit(relax.call_dps_packed(func, x, out_sinfo=R.Tensor((128, 128), "float32")))
        z = bb.emit(relax.call_dps_packed(func, y, out_sinfo=R.Tensor((128, 128), "float32")))
        bb.emit_func_output(z)

    expected = bb.get()["main"]

    _check(parsed, expected)


def test_extern_func_in_module():
    """Module-level parsing may produce function bindings"""

    @I.ir_module
    class parsed_module:
        my_ext = R.ExternFunc("my_ext")

        @R.function
        def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
            return a

    @R.function
    def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
        return a

    expected = tvm.IRModule({"my_ext": relax.ExternFunc("my_ext"), "func": func})

    _check(parsed_module, expected)


def test_define_relax_function_using_global_var():
    """A @R.function may call a GlobalVar

    When parsing a @R.function, the function's body may reference
    GlobalVar instances available in the calling python scope.  The
    resulting function should pass TVMScript's well-formed check, as
    the GlobalVar may be available in the IRModule for which the
    function is being defined.
    """

    @I.ir_module
    class DefinedAllAtOnce:
        @R.function
        def main(A: R.Tensor, B: R.Tensor):
            return DefinedAllAtOnce.subroutine(A, B)

        @R.function(private=True)
        def subroutine(A: R.Tensor, B: R.Tensor) -> R.Tensor:
            return R.matmul(A, B)

    @I.ir_module
    class MainDefinedLater:
        @R.function(private=True)
        def subroutine(A: R.Tensor, B: R.Tensor) -> R.Tensor:
            return R.matmul(A, B)

    subroutine_gvar = MainDefinedLater.get_global_var("subroutine")

    @R.function
    def main(A: R.Tensor, B: R.Tensor):
        return subroutine_gvar(A, B)

    MainDefinedLater["main"] = main

    tvm.ir.assert_structural_equal(DefinedAllAtOnce, MainDefinedLater)


def test_function_attributes_are_defined():
    """func.attrs defaults to an empty DictAttrs"""

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor, shape: R.Shape(["m", "n"])):
            output = Module.subroutine(x, shape)
            return output

        @R.function
        def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]):
            q = x
            m, n = T.int64(), T.int64()
            z = R.match_cast(q, R.Tensor((m, n)))
            w = z
            return w

    for gvar, func in Module.functions.items():
        assert func.attrs is not None


@pytest.mark.xfail(reason="Bug: Implicit bounds not provided when parsing")
def test_function_symbolic_variables_are_annotated():
    """Symbolic variables must be exposed for struct inference

    Because Relax struct inference is performed while the function is
    being built, all constraints on symbolic variables that are used
    for simplifications must be provided to the analyzer.
    """

    @R.function(private=True)
    def inferred_sinfo(A: R.Tensor(["extent"])):
        extent = T.int64()
        output = R.strided_slice(A, [0], [0], [extent - 1])
        return output

    @R.function(private=True)
    def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent-1"]):
        extent = T.int64()
        output: R.Tensor([extent - 1]) = R.strided_slice(A, [0], [0], [extent - 1])
        return output

    tvm.ir.assert_structural_equal(inferred_sinfo, expected)


def test_conditional_may_use_symbolic_variables_from_function_scope():
    """Symbolic variables from function scope may be used in branch

    This is a regression test.  In earlier implementations, the
    branches of `relax::If` were normalized with
    `EraseToWellDefinedInScope`, using a fresh variable scope.  While
    this had the intended behavior of preventing variables defined in
    a single branch from being usable outside of the conditional, it
    also caused the conditional's branches to treat function-scope
    symbolic variables as if they were undefined.

    """

    @R.function(private=True)
    def explicit_sinfo(
        A: R.Tensor(["N"], "float32"),
        B: R.Tensor(["N"], "float32"),
        cond: R.Prim("bool"),
    ) -> R.Tensor(["N"], "float32"):
        N = T.int64()

        if cond:
            out: R.Tensor([N], "float32") = A + B
        else:
            out: R.Tensor([N], "float32") = A * B

        return out

    @R.function(private=True)
    def inferred_sinfo(
        A: R.Tensor(["N"], "float32"),
        B: R.Tensor(["N"], "float32"),
        cond: R.Prim("bool"),
    ):
        N = T.int64()
        if cond:
            out = A + B
        else:
            out = A * B

        return out

    tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo)


def test_return_from_dataflow_block():
    """Return statements imply

    The `R.output` statement in a `R.dataflow()` block marks a
    variable that should be a `relax.Var` instead of a
    `relax.DataflowVar`, allowing it to be used outside of the
    `DataflowBlock` that defined it.  A relax function's output is not
    part of any binding, and must not contain any `DataflowVar`, so
    these are exposed implicitly.

    """

    @R.function(private=True)
    def output_then_return(A: R.Tensor([16], "float16")):
        with R.dataflow():
            B = R.add(A, A)
            C = R.multiply(B, B)
            R.output(C)

        return C

    @R.function(private=True)
    def return_inside_dataflow(A: R.Tensor([16], "float16")):
        with R.dataflow():
            B = R.add(A, A)
            C = R.multiply(B, B)
            return C

    tvm.ir.assert_structural_equal(output_then_return, return_inside_dataflow)


if __name__ == "__main__":
    tvm.testing.main()
